from collections import defaultdict, namedtuple
import numpy as np
import torch
from utility.file_utils import get_subdirs, get_files
import random
import collections
import os
import os.path as osp
import json
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from tqdm import tqdm
import cv2

# Minimum (max - min) range a feature column must exceed before
# normalize_data / unnormalize_data rescale it; columns whose range is at or
# below this value are passed through unchanged.
normalize_threshold = 5e-2

def create_sample_indices(
    episode_ends: np.ndarray,
    sequence_length: int,
    pad_before: int = 0,
    pad_after: int = 0,
):
    """Enumerate every fixed-length window that can be cut from each episode.

    Args:
        episode_ends: Cumulative (exclusive) end index of every episode inside
            the flat, concatenated data buffer.
        sequence_length: Number of steps in each window.
        pad_before: How many steps a window may start before its episode.
        pad_after: How many steps a window may run past its episode's end.

    Returns:
        An array with one row per window:
        ``[buffer_start, buffer_end, sample_start, sample_end]``.
        ``buffer_start:buffer_end`` is the slice of real data in the flat
        buffer, and ``sample_start:sample_end`` is where that slice sits
        inside the ``sequence_length``-long window; positions outside it are
        the ones that need padding.
    """
    windows = []
    episode_start = 0
    for episode_end in episode_ends:
        n_steps = episode_end - episode_start

        # Window offsets are relative to the episode start and may be
        # negative (left padding) or overrun the episode (right padding).
        first_offset = -pad_before
        last_offset = n_steps - sequence_length + pad_after

        for offset in range(first_offset, last_offset + 1):
            # Clip the window to the part that actually exists in the buffer.
            buf_lo = max(offset, 0) + episode_start
            buf_hi = min(offset + sequence_length, n_steps) + episode_start
            # How much was clipped off at each side.
            clipped_front = buf_lo - (offset + episode_start)
            clipped_back = (offset + sequence_length + episode_start) - buf_hi
            windows.append(
                [buf_lo, buf_hi, 0 + clipped_front, sequence_length - clipped_back]
            )

        # The next episode begins where this one ended.
        episode_start = episode_end

    return np.array(windows)


def sample_sequence(
    train_data,
    sequence_length,
    buffer_start_idx,
    buffer_end_idx,
    sample_start_idx,
    sample_end_idx,
):
    """Cut one window out of every array in ``train_data``, padding the edges.

    Args:
        train_data: Mapping from key to an array indexed by step on axis 0.
        sequence_length: Length of the returned window.
        buffer_start_idx: Start of the real-data slice in each array.
        buffer_end_idx: End (exclusive) of the real-data slice in each array.
        sample_start_idx: Position of the slice's first step in the window.
        sample_end_idx: Position just past the slice's last step in the window.

    Returns:
        A dict with the same keys. When no padding is required the value is
        the plain slice of the source array; otherwise it is a new array of
        length ``sequence_length`` whose leading gap repeats the first real
        step and whose trailing gap repeats the last real step.
    """
    pad_front = sample_start_idx > 0
    pad_back = sample_end_idx < sequence_length

    windowed = {}
    for name, source in train_data.items():
        chunk = source[buffer_start_idx:buffer_end_idx]
        if not (pad_front or pad_back):
            windowed[name] = chunk
            continue

        padded = np.zeros(
            shape=(sequence_length,) + source.shape[1:], dtype=source.dtype
        )
        if pad_front:
            # Replicate the first available step into the leading gap.
            padded[:sample_start_idx] = chunk[0]
        if pad_back:
            # Replicate the last available step into the trailing gap.
            padded[sample_end_idx:] = chunk[-1]
        padded[sample_start_idx:sample_end_idx] = chunk
        windowed[name] = padded
    return windowed


# normalize data
def get_data_stats(data):
    """Return the per-feature minimum and maximum of ``data``.

    All leading axes are flattened, so the statistics are taken over every
    row and there is one value per entry of the last axis.

    Returns:
        A dict with keys ``"min"`` and ``"max"``, each a 1-D array whose
        length equals ``data.shape[-1]``.
    """
    flat = data.reshape(-1, data.shape[-1])
    lowest = np.min(flat, axis=0)
    highest = np.max(flat, axis=0)
    return {"min": lowest, "max": highest}


def normalize_data(data, stats, threshold=None):
    """Rescale each feature of ``data`` to the range [-1, 1].

    Feature ``i`` (an entry of the last axis) is mapped with
    ``(x - min_i) / (max_i - min_i) * 2 - 1``. Features whose range
    ``max_i - min_i`` does not exceed ``threshold`` are copied through
    unchanged, which avoids dividing by a (near-)zero range.

    Args:
        data: Array whose last axis indexes features.
        stats: Mapping with ``"min"`` and ``"max"`` entries, each indexable
            by feature (the layout produced by ``get_data_stats``).
        threshold: Minimum range a feature needs in order to be rescaled.
            ``None`` (the default) uses the module-level
            ``normalize_threshold``.

    Returns:
        A new array with the same shape as ``data``; ``data`` itself is not
        modified. Floating-point input keeps its dtype; any other dtype is
        promoted to float64 so the scaled values are not truncated.
    """
    if threshold is None:
        threshold = normalize_threshold

    data = np.asarray(data)
    if np.issubdtype(data.dtype, np.inexact):
        ndata = data.copy()
    else:
        # Writing fractional values into an integer array would silently
        # truncate them, so work in floating point instead.
        ndata = data.astype(np.float64)

    lows, highs = stats["min"], stats["max"]
    # Index the last axis so the features line up with per-feature stats for
    # inputs of any rank (identical to axis 1 for 2-D input).
    for i in range(data.shape[-1]):
        span = highs[i] - lows[i]
        if span > threshold:
            ndata[..., i] = (data[..., i] - lows[i]) / span * 2 - 1
    return ndata


def unnormalize_data(ndata, stats, threshold=None):
    """Map data from [-1, 1] back to its original per-feature range.

    This inverts ``normalize_data``: feature ``i`` (an entry of the last
    axis) is mapped with ``(x + 1) / 2 * (max_i - min_i) + min_i``. Features
    whose range ``max_i - min_i`` does not exceed ``threshold`` were never
    rescaled and are copied through unchanged.

    Args:
        ndata: Normalized array whose last axis indexes features.
        stats: Mapping with ``"min"`` and ``"max"`` entries, each indexable
            by feature (the layout produced by ``get_data_stats``).
        threshold: Minimum range a feature needs in order to be rescaled.
            ``None`` (the default) uses the module-level
            ``normalize_threshold``.

    Returns:
        A new array with the same shape and dtype as ``ndata``. The input
        array is left untouched.
    """
    if threshold is None:
        threshold = normalize_threshold

    data = ndata.copy()
    lows, highs = stats["min"], stats["max"]
    # Index the last axis so the features line up with per-feature stats for
    # inputs of any rank (identical to axis 1 for 2-D input).
    for i in range(ndata.shape[-1]):
        span = highs[i] - lows[i]
        if span > threshold:
            # Read from ndata and write only to the copy, so the caller's
            # array is never overwritten with intermediate values.
            data[..., i] = (ndata[..., i] + 1) / 2 * span + lows[i]
    return data


class RobomimicBCDataset(torch.utils.data.Dataset):
    """Windowed behaviour-cloning dataset assembled from episode directories.

    Every episode directory supplies ``state.npy`` and ``action.npy`` (plus
    numerically named ``.png`` frames when ``obs_image_based`` is set), and a
    matching ``<episode>_emb.npy`` prototype array is read from
    ``proto_dirs``. All episodes are concatenated, normalized per feature,
    and served as fixed-length windows of ``pred_horizon`` steps.
    """

    def __init__(
        self,
        data_dirs,
        proto_dirs,
        pred_horizon,
        obs_horizon,
        action_horizon,
        resize_shape=None,
        proto_horizon=None,
        raw_representation=False,
        softmax_prototype=False,
        prototype=False,
        one_hot_prototype=False,
        prototype_snap=False,
        mask=None,
        obs_image_based=False,
        unnormal_list=None,
        pipeline=None,
        verbose=False,
        seed=0,
    ):
        """
        Support 1) raw representation 2) softmax prototype 3) prototype 4) one-hot prototype

        Args:
            data_dirs: Iterable of root directories; each is scanned for
                episode sub-directories.
            proto_dirs: Directory holding the per-episode prototype files.
            pred_horizon: Length of every sampled window.
            obs_horizon: Number of leading steps kept for observations,
                prototypes and images; windows may start up to
                ``obs_horizon - 1`` steps before an episode.
            action_horizon: Windows may overrun an episode's end by up to
                ``action_horizon - 1`` steps.
            resize_shape: Optional size handed to ``cv2.resize`` for every
                frame (OpenCV reads it as ``(width, height)``).
            proto_horizon: Number of trailing prototype steps to keep;
                defaults to ``obs_horizon``.
            raw_representation, softmax_prototype, prototype,
            one_hot_prototype, pipeline: Stored on the instance; not used by
                the code in this class.
            prototype_snap: Also load and return a ``proto_snap`` entry (see
                the review note in ``load_data``).
            mask: Optional path to a ``.npy`` file, loaded and stored as a
                list in ``self.mask``.
            obs_image_based: Load image frames and return them as ``images``.
            unnormal_list: Keys that must be excluded from normalization.
                ``None`` means no exclusions.
            verbose: Enable the loading progress bar.
            seed: Seed applied to the ``random``, NumPy and torch generators.
        """
        self.verbose = verbose
        self.resize_shape = resize_shape
        self.mask = np.load(mask).tolist() if mask is not None else None

        self.seed = seed
        self.set_seed(self.seed)
        self.raw_representation = raw_representation
        self.softmax_prototype = softmax_prototype
        self.prototype = prototype
        self.one_hot_prototype = one_hot_prototype
        self.obs_image_based = obs_image_based
        self.prototype_snap = prototype_snap
        self.pipeline = pipeline
        # A list literal as the default would be shared by every instance
        # built without this argument; create a fresh one per instance.
        self.unnormal_list = [] if unnormal_list is None else unnormal_list

        self.data_dirs = data_dirs
        self.proto_dirs = proto_dirs
        self._build_dir_tree()

        train_data = defaultdict(list)
        self.load_data(train_data)

        # Per-episode lengths, taken before the episodes are concatenated.
        episode_lengths = [len(eps_actions) for eps_actions in train_data["actions"]]

        for key, episodes in train_data.items():
            train_data[key] = np.concatenate(episodes)

        print(f"training data len {len(train_data['actions'])}")

        episode_ends = np.cumsum(episode_lengths)
        self.episode_ends = episode_ends

        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon - 1,
            pad_after=action_horizon - 1,
        )

        stats = dict()
        for key, data in train_data.items():
            # Images and explicitly excluded keys stay in their raw form.
            if key == "images" or key in self.unnormal_list:
                continue
            stats[key] = get_data_stats(data)
            train_data[key] = normalize_data(data, stats[key])

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon
        self.proto_horizon = obs_horizon if proto_horizon is None else proto_horizon

    def set_seed(self, seed):
        """Seed the NumPy, ``random`` and torch global generators."""
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    def _build_dir_tree(self):
        """Build a dict of indices for iterating over the dataset.

        Maps every entry of ``self.data_dirs`` that has at least one
        sub-directory to an array of those sub-directories.
        """
        self._dir_tree = collections.OrderedDict()
        for path in self.data_dirs:
            vids = get_subdirs(
                path,
                nonempty=False,
                sort_numerical=True,
            )
            if vids:
                self._dir_tree[path] = np.array(vids)

    def load_action_and_to_tensor(self, vid):
        """Load ``action.npy`` from episode directory ``vid`` as float32."""
        action_path = os.path.join(vid, "action.npy")
        action_data = np.load(action_path)
        return np.array(action_data, dtype=np.float32)

    def load_state_and_to_tensor(self, vid):
        """Load ``state.npy`` from episode directory ``vid`` as stored."""
        state_path = os.path.join(vid, "state.npy")
        return np.load(state_path)

    def load_proto_and_to_tensor(self, vid, load_mode="emb"):
        """Load the prototype array that belongs to episode ``vid``.

        Args:
            vid: Episode directory; only its final path component is used to
                build the file name inside ``self.proto_dirs``.
            load_mode: ``"emb"`` reads ``<episode>_emb.npy``; ``"label"``
                reads ``<episode>_label.npy`` and adds a trailing axis.

        Returns:
            A float32 array with three zero-filled columns appended on
            axis 1.

        Raises:
            ValueError: If ``load_mode`` is not one of the supported values.
        """
        episode_name = os.path.basename(os.path.normpath(vid))
        if load_mode == "emb":
            proto_path = osp.join(self.proto_dirs, f"{episode_name}_emb.npy")
            proto_data = np.load(proto_path)
        elif load_mode == "label":
            proto_path = osp.join(self.proto_dirs, f"{episode_name}_label.npy")
            proto_data = np.expand_dims(np.load(proto_path), axis=1)
        else:
            raise ValueError(f"unsupported load_mode: {load_mode!r}")

        # NOTE(review): the purpose of the three extra zero columns is not
        # visible in this file -- confirm with the consumer of "protos".
        proto_data = np.append(
            proto_data, np.zeros((proto_data.shape[0], 3)), axis=1
        )
        return np.array(proto_data, dtype=np.float32)  # (T,D)

    def load_images(self, vid):
        """Load every ``.png`` frame of episode ``vid`` in numeric order.

        File names (without extension) must parse as integers, as they are
        used as the sort key. Frames are resized with ``cv2.resize`` when
        ``self.resize_shape`` is set.

        Returns:
            A uint8 array of the stacked frames.
        """
        filenames = sorted(
            [f for f in os.listdir(Path(vid)) if f.endswith(".png")],
            key=lambda x: int(os.path.splitext(x)[0]),
        )

        images = []
        for filename in filenames:
            # The context manager releases the file as soon as the pixels
            # have been copied into the array.
            with Image.open(os.path.join(vid, filename)) as img:
                img_arr = np.array(img)
            if self.resize_shape is not None:
                img_arr = cv2.resize(img_arr, self.resize_shape)
            images.append(img_arr)
        images_arr = np.array(images)
        assert images_arr.dtype == np.uint8
        return images_arr

    def transform_images(self, images_arr):
        """Convert channel-last images to channel-first float32 in [0, 1].

        Moves axis 3 to position 1 and divides by 255.
        """
        images_arr = images_arr.astype(np.float32)
        return np.transpose(images_arr, (0, 3, 1, 2)) / 255.0

    def load_data(self, train_data):
        """Append every episode's arrays to the lists in ``train_data``.

        Fills the ``obs``, ``protos`` and ``actions`` lists, plus ``images``
        when ``obs_image_based`` is set and ``proto_snap`` when
        ``prototype_snap`` is set.
        """
        # NOTE(review): only the first directory of the tree is read; any
        # further entries of data_dirs are ignored -- confirm this is
        # intended. Raises IndexError when no episode directory was found.
        vid = list(self._dir_tree.values())[0]

        print("loading data")
        # Iterating the array directly lets tqdm know the total.
        for v in tqdm(vid, desc="Loading data", disable=not self.verbose):
            if self.obs_image_based:
                train_data["images"].append(self.load_images(v))

            train_data["obs"].append(self.load_state_and_to_tensor(v))
            if self.prototype_snap:
                # NOTE(review): load_proto_and_to_tensor returns one array,
                # so this two-way unpacking splits it along axis 0 and fails
                # unless it has exactly two rows. The loader for the snap
                # data appears to be missing.
                proto_data, proto_snap = self.load_proto_and_to_tensor(v)
                train_data["proto_snap"].append(proto_snap)
            else:
                proto_data = self.load_proto_and_to_tensor(v)

            train_data["protos"].append(proto_data)
            train_data["actions"].append(self.load_action_and_to_tensor(v))

    def __len__(self):
        """Return the number of sampleable windows."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return window ``idx`` as a dict of arrays.

        ``actions`` covers the full ``pred_horizon``; ``obs`` (and
        ``images``, if enabled) are cut to the first ``obs_horizon`` steps;
        ``protos`` keeps the last ``proto_horizon`` of those steps, or only
        the last one when ``prototype_snap`` is set.
        """
        (
            buffer_start_idx,
            buffer_end_idx,
            sample_start_idx,
            sample_end_idx,
        ) = self.indices[idx]

        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx,
        )

        nsample["obs"] = nsample["obs"][: self.obs_horizon, :]
        observed_protos = nsample["protos"][: self.obs_horizon, :]
        if self.prototype_snap:
            nsample["protos"] = observed_protos[-1:, :]
            nsample["proto_snap"] = nsample["proto_snap"][-1:, :]
        else:
            nsample["protos"] = observed_protos[-self.proto_horizon :, :]

        if self.obs_image_based:
            # Slice first so only the frames that are returned get converted.
            nsample["images"] = self.transform_images(
                nsample["images"][: self.obs_horizon]
            )
            # NOTE(review): only the first 9 state features are kept in
            # image mode; the meaning of that cut-off is not visible here.
            nsample["obs"] = nsample["obs"][: self.obs_horizon, :9]

        return nsample
